[skyrl] Add /sample endpoint to RemoteInferenceClient following Tinker API #1396
pcmoritz merged 8 commits into NovaSky-AI:main from
Conversation
- Add RemoteInferenceClient.sample() mapping Tinker-style sample requests to the vLLM /inference/v1/generate endpoint
- Support n completions, logprobs, and configurable sampling params
- Add unit tests (n=1, n=2, session_id routing)
- Add GPU integration tests (sample, sample_multiple, sample_deterministic)
- Simplify _force_close_connector to use transport.close() directly
55bc8e7 to 929e25b
Code Review
This pull request introduces a new sample method to RemoteInferenceClient to support the Tinker API, along with corresponding unit tests and updates to the mock inference server. I have provided feedback regarding the optimization of the _PARAM_MAP constant, the need for a test case covering session_id routing, and a correction for the num_choices logic in the mock server.
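The review refers to a `_PARAM_MAP` constant for translating Tinker-style sampling parameters to vLLM's generate parameters. As a hedged illustration only (the mapping entries and `translate_params` helper below are assumptions, not the PR's actual code), such a table-driven translation might look like:

```python
# Hypothetical sketch of a Tinker -> vLLM sampling-parameter map.
# The key names below are assumptions for illustration; the PR's real
# _PARAM_MAP may differ in entries and spelling.
_PARAM_MAP = {
    "temperature": "temperature",
    "top_p": "top_p",
    "top_k": "top_k",
    "seed": "seed",
    "max_tokens": "max_tokens",
    "stop": "stop",
}

def translate_params(tinker_params: dict) -> dict:
    """Keep only recognized params, renamed to their vLLM equivalents."""
    return {
        vllm_key: tinker_params[tinker_key]
        for tinker_key, vllm_key in _PARAM_MAP.items()
        if tinker_key in tinker_params
    }

print(translate_params({"temperature": 0.0, "seed": 42, "unknown": 1}))
# → {'temperature': 0.0, 'seed': 42}
```

Defining the map at module level (rather than rebuilding it per call) is the kind of optimization the review comment alludes to.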
```python
def test_client_sample_deterministic(vllm_server: InferenceEngineState):
    """Test that sample with seed + temperature=0 is deterministic across calls."""
    client = vllm_server.client
    token_ids = _get_test_token_ids(MODEL_QWEN2_5)
    params = {"temperature": 0.0, "max_tokens": 32, "seed": 42}

    result1 = asyncio.run(client.sample(_build_sample_payload(token_ids, num_samples=1, sampling_params=params)))
    result2 = asyncio.run(client.sample(_build_sample_payload(token_ids, num_samples=1, sampling_params=params)))

    assert result1["sequences"][0]["tokens"] == result2["sequences"][0]["tokens"]
```
The pull request description mentions adding a unit test for session_id routing for the sample method, but it seems to be missing from the submitted tests. Please consider adding a test case that utilizes the session_id parameter in _build_sample_payload to verify that session-based routing works as expected for the new endpoint.
Leaving the arg in for _build_sample_payload since we may want to test it in the future. I'm not sure how to test session-based routing in our current setup, so leaving it for now.
…client.py revert change Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com>
```python
        # Transform response choices → sequences
        sequences = []
        logger.info("num choices: %d", len(response.get("choices", [])))
```
Always logging with info here is probably a little too verbose, right?
Yes, I put it in for debugging originally. It shouldn't be in, and I removed it.
```python
    return {
        "type": "sample",
        "sequences": sequences,
        "prompt_logprobs": None,
```
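The hunk above builds the Tinker-style result object out of the transformed sequences. A hedged, self-contained sketch of the choices → sequences transform (the `token_ids` and `logprobs` field names on the vLLM choice are assumptions here, not necessarily the PR's exact schema):

```python
# Hypothetical choices → sequences transform; field names on each
# choice are assumptions for illustration.
response = {
    "choices": [
        {"token_ids": [5, 6], "logprobs": [-0.1, -0.2]},
        {"token_ids": [7], "logprobs": [-0.3]},
    ]
}

sequences = [
    {"tokens": choice["token_ids"], "logprobs": choice["logprobs"]}
    for choice in response.get("choices", [])
]

result = {"type": "sample", "sequences": sequences, "prompt_logprobs": None}
print(len(result["sequences"]))  # → 2
```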
Going forward, we might want / need to support this :)
Yes! The next PR will include prompt_logprobs, but I need to check how they handle prompt logprobs for vision to make sure we handle that.
```python
    tinker_params = body.get("sampling_params", {})

    # Flatten prompt chunks → token IDs
    token_ids = [tok for chunk in prompt.get("chunks", []) for tok in chunk.get("tokens", [])]
```
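To make the flattening step above concrete, here is a minimal runnable example with made-up chunk data (the prompt contents are illustrative, not from the PR):

```python
# A Tinker-style prompt stores token IDs in ordered "chunks"; the nested
# comprehension concatenates them into one flat token list.
prompt = {"chunks": [{"tokens": [1, 2, 3]}, {"tokens": [4, 5]}]}

token_ids = [tok for chunk in prompt.get("chunks", []) for tok in chunk.get("tokens", [])]
print(token_ids)  # → [1, 2, 3, 4, 5]
```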
This will need adaptation for multi-modal inputs going forward, right?
Yes, this will have to be the token concatenation we talked about, so it will get replaced.
Replace asyncio.run() with await in test_client_sample, test_client_sample_multiple, and test_client_sample_deterministic. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
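The commit above replaces per-call `asyncio.run()` with `await` inside async tests. As a hedged sketch of why that matters (the `FakeClient` below is a stand-in, not the PR's client), running both calls inside a single event loop looks like:

```python
import asyncio

class FakeClient:
    """Stand-in for RemoteInferenceClient; returns a fixed sample result."""
    async def sample(self, payload):
        return {"sequences": [{"tokens": [1, 2, 3]}]}

async def sample_twice(client, payload):
    # One event loop for both calls, instead of two separate asyncio.run()
    # invocations that each spin up and tear down a loop.
    r1 = await client.sample(payload)
    r2 = await client.sample(payload)
    return r1, r2

r1, r2 = asyncio.run(sample_twice(FakeClient(), {}))
print(r1["sequences"][0]["tokens"] == r2["sequences"][0]["tokens"])  # → True
```

In the actual tests this would rely on an async test runner such as pytest-asyncio (an assumption here), so the test function itself can be declared `async def` and `await` the client directly.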
Add /sample API to RemoteInferenceClient

This PR adds the Tinker-compatible /sample API to RemoteInferenceClient on the new inference server codepath, addressing #1286.

Changes
- RemoteInferenceClient.sample() method that maps Tinker-style sample requests to the vLLM /inference/v1/generate endpoint, supporting n completions, logprobs, and configurable sampling params (temperature, top_k, top_p, seed, stop tokens, etc.)

Tests
- Unit tests (TestSample) covering n=1, n=2, and multi-chunk prompts
- GPU integration tests (test_client_sample, test_client_sample_multiple, test_client_sample_deterministic) validating end-to-end generation against a live vLLM server